# encoding: utf-8

import torch

import config

from model import Retriever

# Build each Retriever size on CPU and print its parameter count.
device = torch.device("cpu")
# NOTE: an earlier bfloat16 run apparently reported 109806336 for the medium
# config, the same count recorded below for float32. The parameter count should
# not depend on dtype, but confirm if get_num_params changes.
dtype = torch.float32

# (label, config factory) pairs. The factories are called lazily, one per
# iteration, so only the config currently being measured is constructed.
CONFIGS = (
    ("small", config.RetrieverConfig_small),
    ("medium", config.RetrieverConfig_medium),
    ("large", config.RetrieverConfig_large),
)

for name, make_cfg in CONFIGS:
    model = Retriever(device=device, ptdtype=dtype, config=make_cfg(), flash=True)
    print(f"{name}: {model.get_num_params()}")
    # Drop the reference before the next (larger) model is built. Otherwise the
    # previous model stays referenced while the new one is allocated.
    del model

# Previously recorded counts:
#   small:  35326464
#   medium: 109806336
#   large:  396053760